import torch
import torch.nn as nn
import torch.optim as optim
import shapely.geometry as geometry
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
class BoxPredictor(nn.Module):
    """Tiny MLP that maps a 4-value input to a predicted box (x, y, w, h)."""

    def __init__(self):
        super().__init__()
        hidden_units = 8
        # 4 -> 8 -> 4 with a learnable PReLU non-linearity in the middle.
        self.fc = nn.Sequential(
            nn.Linear(4, hidden_units),
            nn.PReLU(),
            nn.Linear(hidden_units, 4),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the MLP on ``x`` and return its 4 outputs (x, y, w, h)."""
        box = self.fc(x)
        return box

    def create_box(self, x: float, y: float, w: float, h: float) -> geometry.Polygon:
        """Build an axis-aligned shapely box from centre/size parameters.

        Args:
            x (float): centre x of the box
            y (float): centre y of the box
            w (float): box width
            h (float): box height

        Returns:
            geometry.Polygon: rectangle spanning [x - w/2, x + w/2] x [y - h/2, y + h/2]
        """
        half_w = w / 2
        half_h = h / 2
        return geometry.box(x - half_w, y - half_h, x + half_w, y + half_h)
class DIoU(nn.Module):
    """Computes the Distance-IoU (DIoU) loss between boxes.

    Boxes are encoded as ``(x, y, w, h)`` where ``(x, y)`` is the box centre.
    The loss is ``1 - IoU + d**2 / c**2``, where ``d`` is the distance between
    the two centres and ``c`` is the diagonal of the smallest axis-aligned box
    enclosing both boxes.
    """

    def __init__(self, eps: float = 1e-7):
        """
        Args:
            eps (float): small constant added to the IoU and penalty
                denominators so degenerate (zero-area or coincident) boxes
                do not produce NaN/inf.
        """
        super().__init__()
        self.eps = eps

    def forward(
        self, pred: torch.Tensor, target: torch.Tensor
    ) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]":
        """Compute the DIoU loss.

        Args:
            pred (torch.Tensor): box(es) predicted by the model, shape ``(..., 4)``
            target (torch.Tensor): target box(es), shape ``(..., 4)``,
                broadcastable against ``pred``

        Returns:
            tuple: ``(diou, iou, distance, diagonal)``, each with the
            broadcast batch shape (0-d tensors for single boxes). Reduce
            ``diou`` (e.g. ``.mean()``) before calling ``backward`` on batches.
        """
        # unbind(-1) works for both a single box (4,) and batches (..., 4);
        # plain tuple-unpacking would split a batch along its first dim.
        px, py, pw, ph = pred.unbind(-1)
        tx, ty, tw, th = target.unbind(-1)

        # Corner coordinates of both boxes.
        # NOTE(review): w/h produced by a model are not clamped to be
        # non-negative here; negative sizes yield negative areas.
        p_x1, p_y1 = px - pw / 2, py - ph / 2
        p_x2, p_y2 = px + pw / 2, py + ph / 2
        t_x1, t_y1 = tx - tw / 2, ty - th / 2
        t_x2, t_y2 = tx + tw / 2, ty + th / 2

        # Intersection rectangle; clamp so disjoint boxes give zero overlap.
        inter_w = (torch.minimum(p_x2, t_x2) - torch.maximum(p_x1, t_x1)).clamp(min=0)
        inter_h = (torch.minimum(p_y2, t_y2) - torch.maximum(p_y1, t_y1)).clamp(min=0)
        intersection = inter_w * inter_h

        union = pw * ph + tw * th - intersection
        iou = intersection / (union + self.eps)

        # Euclidean distance between the two box centres.
        distance = torch.linalg.norm(pred[..., :2] - target[..., :2], dim=-1)

        # Diagonal of the smallest box enclosing both. Built with tensor ops
        # (not torch.tensor([...])) so the penalty term keeps its gradient.
        enclose_w = torch.maximum(p_x2, t_x2) - torch.minimum(p_x1, t_x1)
        enclose_h = torch.maximum(p_y2, t_y2) - torch.minimum(p_y1, t_y1)
        diagonal = torch.linalg.norm(torch.stack([enclose_w, enclose_h], dim=-1), dim=-1)

        diou = 1 - iou + (distance ** 2) / (diagonal ** 2 + self.eps)
        return diou, iou, distance, diagonal
torch.manual_seed(777)

# Train
predictor = BoxPredictor()
loss_function = DIoU()
optimizer = optim.Adam(predictor.parameters())
epochs = 150

# Ground-truth box and the fixed input box, both as (x, y, w, h),
# plus their shapely outlines for plotting.
label_box = torch.tensor([1, 0, 3, 5], dtype=torch.float32)
label_box_geometry = predictor.create_box(*label_box.tolist())
input_box = torch.tensor([8, -7.3, 1, 1], dtype=torch.float32)
input_box_geometry = predictor.create_box(*input_box.tolist())

# Per-epoch history; the prediction stored is the one the loss was computed on
# (i.e. before that epoch's optimizer step).
predictions, losses, ious, distances, diagonals = [], [], [], [], []

for epoch in range(epochs):
    optimizer.zero_grad()
    pred = predictor(input_box)
    diou, iou, distance, diagonal = loss_function(pred, label_box)
    diou.backward()
    optimizer.step()

    predictions.append(pred.detach().clone())
    for history, metric in zip(
        (losses, ious, distances, diagonals),
        (diou, iou, distance, diagonal),
    ):
        history.append(metric.item())
# Visualize
fig, ax = plt.subplots()


def animate(frame):
    """Redraw the scene for training step ``frame``.

    Draws the input, ground-truth and predicted boxes, plus a dotted line
    between the predicted and ground-truth centroids, and shows that step's
    DIoU / IoU / centre distance in the title.
    """
    pred = predictions[frame]
    pred_geometry = predictor.create_box(*pred.tolist())
    distance = geometry.LineString([pred_geometry.centroid, label_box_geometry.centroid])

    # The whole axes is wiped and redrawn every frame.
    ax.clear()
    ax.set_xlim(-10, 10)
    ax.set_ylim(-10, 10)
    ax.grid(True, alpha=0.3)
    ax.plot(*input_box_geometry.exterior.xy, color="blue", label="input", linewidth=1)
    ax.plot(*label_box_geometry.exterior.xy, color="green", label="ground-truth", linewidth=1)
    ax.plot(*pred_geometry.exterior.xy, color="black", label="predicted", linewidth=1)
    ax.plot(*distance.xy, color="red", label="distance", linewidth=1, linestyle="dotted")
    ax.set_aspect("equal")
    ax.legend(loc="upper left")
    ax.set_title(f"Epoch: {frame + 1}, DIoU: {losses[frame]:.5f}, IoU: {ious[frame]:.5f}, d: {distances[frame]:.5f} \n", fontsize=9)


# blit=False: animate() clears and redraws the entire axes rather than
# returning the artists it changed, so blitting would leave a live display
# unrefreshed. frames follows the recorded history so the two cannot drift.
anim = animation.FuncAnimation(
    fig, animate, frames=len(predictions), interval=50, blit=False, repeat=False
)
# Presumably closes the figure so a notebook shows only the HTML animation,
# not an extra static plot.
plt.close(fig)
HTML(anim.to_jshtml())